{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04192.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04192.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YmFH76O2iOH6",
        "colab_type": "text"
      },
      "source": [
        "## 5.3 多输入通道和多输出通道"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tr8x37dmij-e",
        "colab_type": "text"
      },
      "source": [
        "### 5.3.1 多输入通道"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_7-OtiY3ioRI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch \n",
        "from torch import nn \n",
        "import d2l \n",
        "\n",
        "\n",
        "def corr2d_multi_in(X, K):\n",
        "    res = d2l.corr2d(X[0, :, :], K[0, :, :])\n",
        "    for i in range(1, X.shape[0]):\n",
        "        res += d2l.corr2d(X[i, :, :], K[i, :, :])\n",
        "    return res "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zXyNoJJs1p3X",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "bca3bd30-d522-4776-86f6-d877d8eed576"
      },
      "source": [
        "X = torch.tensor([[[0, 1, 2], [3, 4, 5], [6, 7, 8]], \n",
        "          [[1, 2, 3], [4, 5, 6], [7, 8, 9]]])\n",
        "K = torch.tensor([[[0, 1], [2, 3]], [[1, 2], [3, 4]]])\n",
        "\n",
        "corr2d_multi_in(X, K)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[ 56.,  72.],\n",
              "        [104., 120.]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x4Fo2Kcw2Djx",
        "colab_type": "text"
      },
      "source": [
        "### 5.3.2 多输出通道"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zKDDaHj92JYA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def corr2d_multi_in_out(X, K):\n",
        "    return torch.stack([corr2d_multi_in(X, k) for k in K])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vb38hGVq2Whv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "93e87bbf-8d4f-4ff2-9372-9f586e66af98"
      },
      "source": [
        "K = torch.stack([K, K + 1, K + 2])\n",
        "K.shape"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([3, 2, 2, 2])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qb1_EoRf2cKf",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 151
        },
        "outputId": "10adbbb6-e3d6-40f1-8f47-0ad960e087aa"
      },
      "source": [
        "corr2d_multi_in_out(X, K)"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[[ 56.,  72.],\n",
              "         [104., 120.]],\n",
              "\n",
              "        [[ 76., 100.],\n",
              "         [148., 172.]],\n",
              "\n",
              "        [[ 96., 128.],\n",
              "         [192., 224.]]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OIv-qg2y2gvH",
        "colab_type": "text"
      },
      "source": [
        "### 5.3.3 1 X 1 卷积层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xJZAsTNS477F",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def corr2d_multi_in_out_1x1(X, K):\n",
        "    c_i, h, w = X.shape \n",
        "    c_o = K.shape[0]\n",
        "    X = X.view(c_i, h * w)\n",
        "    K = K.view(c_o, c_i)\n",
        "    Y = torch.mm(K, X)\n",
        "    return Y.view(c_o, h, w)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3N1odqec5ZGB",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "7ee32c19-1735-4731-9179-82b447f6f679"
      },
      "source": [
        "X = torch.rand(3, 3, 3)\n",
        "K = torch.rand(2, 3, 1, 1)\n",
        "\n",
        "Y1 = corr2d_multi_in_out_1x1(X, K)\n",
        "Y2 = corr2d_multi_in_out(X, K)\n",
        "\n",
        "(Y2 - Y1).norm().item() < 1e-6"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    }
  ]
}